/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 *modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *notice, this list of conditions and the following disclaimer in the
 *documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *contributors may be used to endorse or promote products derived from this
 *software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING
 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines basic properties needed by CTA-level batched GEMV assuming
   expectations about data layout of the global memory fragments, data types,
   and internal tile sizes.

      Partial specializations for threadblock::Mma operations targeting SIMT
   instructions.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"

#include "cutlass/layout/matrix.h"

#include "cutlass/platform/platform.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/thread/mma.h"

#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "cutlass/transform/pitch_linear_thread_map.h"

#include "cutlass/gemm/threadblock/gemv.h"

/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Template defininng default vector-matrix multiply operators inferred from
/// threadblock tile size, global memory data layout.
template <typename Shape_,  /// Shape of the threadblock vector-matrix multiply
                            /// operator
          typename ThreadShape_,  /// Shape of per-thread vector-matrix multiply
                                  /// operator
          typename ElementA_,     /// Element data type of A operand
          typename LayoutA_,      /// Layout of operand A
          typename ElementB_,     /// Element data type of B operand
          typename LayoutB_,      /// Layout of operand B
          typename ElementC_,     /// Data type of accumulator
          typename LayoutC_       /// Layout of accumulator
          >
struct DefaultGemvCore {
    using Shape = Shape_;
    using ThreadShape = ThreadShape_;

    using LayoutA = LayoutA_;
    using LayoutB = LayoutB_;
    using LayoutC = LayoutC_;

    using ElementA = ElementA_;
    using ElementB = ElementB_;
    using ElementC = ElementC_;

    static int const kThreadsPerN = Shape::kN / ThreadShape::kN;

    using IteratorPolicyA = typename platform::conditional<
            platform::is_same<LayoutA, layout::RowMajor>::value,
            cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous<
                    layout::PitchLinearShape<Shape::kK, Shape::kM>, 1,
                    ThreadShape::kK>,
            cutlass::transform::PitchLinearTilePolicyStripminedThreadStrided<
                    layout::PitchLinearShape<Shape::kM, Shape::kK>, 1,
                    ThreadShape::kM>>::type;

    using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
            cutlass::MatrixShape<Shape::kM, Shape::kK>, ElementA, LayoutA, 1,
            IteratorPolicyA>;

    using IteratorPolicyB = typename platform::conditional<
            platform::is_same<LayoutB, layout::RowMajor>::value,
            cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous<
                    layout::PitchLinearShape<Shape::kN, Shape::kK>,
                    kThreadsPerN, ThreadShape::kN>,
            cutlass::transform::PitchLinearTilePolicyStripminedThreadStrided<
                    layout::PitchLinearShape<Shape::kK, Shape::kN>,
                    kThreadsPerN, ThreadShape::kK>>::type;

    using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
            cutlass::MatrixShape<Shape::kK, Shape::kN>, ElementB, LayoutB, 0,
            IteratorPolicyB>;

    using IteratorPolicyC = typename platform::conditional<
            platform::is_same<LayoutC, layout::RowMajor>::value,
            cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous<
                    layout::PitchLinearShape<Shape::kN, Shape::kM>,
                    kThreadsPerN, ThreadShape::kN>,
            cutlass::transform::PitchLinearTilePolicyStripminedThreadStrided<
                    layout::PitchLinearShape<Shape::kM, Shape::kN>,
                    kThreadsPerN, ThreadShape::kM>>::type;

    using IteratorC = cutlass::transform::threadblock::PredicatedTileIterator<
            cutlass::MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC, 0,
            IteratorPolicyC>;

    using MmaSimtOp = typename cutlass::gemm::thread::Mma<
            cutlass::gemm::GemmShape<ThreadShape::kM, ThreadShape::kN,
                                     Shape::kK>,
            ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC>;

    using Operator = MmaSimtOp;

    // Assertions for correctness
    static_assert((Shape::kM == 1), "M=1 is required for GEMV");

    static_assert((ThreadShape::kM == 1), "M=1 is required for GEMV");

    static_assert(Shape::kK % ThreadShape::kK == 0,
                  "Shape::K must be a multiple of ThreadShape::K");

    static_assert(((ThreadShape::kK == 1) || (ThreadShape::kK == 2) ||
                   (ThreadShape::kK == 4) || (ThreadShape::kK == 8) ||
                   (ThreadShape::kK == 16) || (ThreadShape::kK == 32)),
                  "ThreadShape::K must be a 1, 2, 4, 8, 16 or 32");
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for 
/// LayoutA: cutlass::layout::RowMajor
/// LayoutB: cutlass::layout::RowMajor
/// LayoutC: cutlass::layout::RowMajor
template <typename Shape_,  /// Shape of the threadblock vector-matrix multiply
                            /// operator
          typename ThreadShape_,  /// Shape of per-thread vector-matrix multiply
                                  /// operator
          typename ElementA_,     /// Element data type of A operand
          typename ElementB_,     /// Element data type of B operand
          typename ElementC_>     /// Data type of accumulator
struct DefaultGemvCore<
        Shape_, ThreadShape_, ElementA_, cutlass::layout::RowMajor, ElementB_,
        cutlass::layout::RowMajor, ElementC_, cutlass::layout::RowMajor> {
    using Shape = Shape_;
    using ThreadShape = ThreadShape_;

    using LayoutA = cutlass::layout::RowMajor;
    using LayoutB = cutlass::layout::RowMajor;
    using LayoutC = cutlass::layout::RowMajor;

    using ElementA = ElementA_;
    using ElementB = ElementB_;
    using ElementC = ElementC_;

    static int const kThreadsPerN = Shape::kN / ThreadShape::kN;
    static int const kThreadsPerK = Shape::kK / ThreadShape::kK;
    using ThreadArrangement =
            cutlass::layout::PitchLinearShape<kThreadsPerN, kThreadsPerK>;

    using IteratorPolicyA =
            cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous<
                    layout::PitchLinearShape<ThreadShape::kK, Shape::kM>, 1,
                    ThreadShape::kK>;

    using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
            cutlass::MatrixShape<Shape::kM, Shape::kK>, ElementA, LayoutA, 1,
            IteratorPolicyA>;

    using IteratorPolicyB = cutlass::transform::
            PitchLinear2DTilePolicyStripminedThreadContiguous<
                    layout::PitchLinearShape<Shape::kN, Shape::kK>,
                    ThreadArrangement, ThreadShape::kN>;

    using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
            cutlass::MatrixShape<Shape::kK, Shape::kN>, ElementB, LayoutB, 0,
            IteratorPolicyB>;

    using IteratorPolicyC = typename platform::conditional<
            platform::is_same<LayoutC, layout::RowMajor>::value,
            cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous<
                    layout::PitchLinearShape<Shape::kN, Shape::kM>,
                    kThreadsPerN, ThreadShape::kN>,
            cutlass::transform::PitchLinearTilePolicyStripminedThreadStrided<
                    layout::PitchLinearShape<Shape::kM, Shape::kN>,
                    kThreadsPerN, ThreadShape::kM>>::type;

    using IteratorC = cutlass::transform::threadblock::PredicatedTileIterator<
            cutlass::MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC, 0,
            IteratorPolicyC>;

    using MmaSimtOp = typename cutlass::gemm::thread::Mma<
            cutlass::gemm::GemmShape<ThreadShape::kM, ThreadShape::kN,
                                     ThreadShape::kK>,
            ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC>;

    using Operator = MmaSimtOp;

    // Assertions for correctness
    static_assert((Shape::kM == 1), "M=1 is required for GEMV");

    static_assert((ThreadShape::kM == 1), "M=1 is required for GEMV");

    static_assert(Shape::kK % ThreadShape::kK == 0,
                  "Shape::K must be a multiple of ThreadShape::K");

    static_assert(((ThreadShape::kK == 1) || (ThreadShape::kK == 2) ||
                   (ThreadShape::kK == 4) || (ThreadShape::kK == 8) ||
                   (ThreadShape::kK == 16) || (ThreadShape::kK == 32)),
                  "ThreadShape::K must be a 1, 2, 4, 8, 16 or 32");
};

/////////////////////////////////////////////////////////////////////////////////////////////////


}  // namespace threadblock
}  // namespace gemm
}  // namespace cutlass
